Better support for classification tasks with large number of label classes #561
Conversation
…e list of labels by similarity to the prompt
# This is the VectorStore class that is used to store the embeddings and do a similarity search over.
VectorStoreWrapper(cache=False),
# This is the number of examples to produce.
k=10,
We should make this value of k configurable. Maybe 10 is a reasonable default, but we might want workflows in the future that automatically search for the right value of k.
done in f97a82d
The value of k can now be specified in the autolabel config file like so:
"label_selection": true,
"label_selection_count": 10
Did we benchmark this on Banking/Ledgar? I hope there isn't a big drop in performance using this approach.
# This is the list of labels available to select from.
label_examples,
# This is the embedding class used to produce embeddings which are used to measure semantic similarity.
OpenAIEmbeddings(),
Can we use the same embedding model as the one used for the seed examples? It can be read from the config.
+1 to read this from the embedding model section in Autolabel config
done in latest commit.
It now chooses the embedding function based on config.embedding_provider():
self.label_selector = LabelSelector.from_examples(
labels=self.config.labels_list(),
k=self.config.label_selection_count(),
embedding_func=PROVIDER_TO_MODEL.get(
self.config.embedding_provider(), DEFAULT_EMBEDDING_PROVIDER
)(),
)
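As a rough sketch of the dispatch pattern in the snippet above (the stand-in classes and the mapping's contents here are assumptions, not the real PROVIDER_TO_MODEL):

```python
class OpenAIEmbeddings:
    """Stand-in for the real OpenAI embedding class."""

class HuggingFaceEmbeddings:
    """Stand-in for an alternative provider's embedding class."""

# Hypothetical mapping from provider name to embedding class;
# unknown providers fall back to the default.
PROVIDER_TO_MODEL = {
    "openai": OpenAIEmbeddings,
    "huggingface": HuggingFaceEmbeddings,
}
DEFAULT_EMBEDDING_PROVIDER = OpenAIEmbeddings

def embedding_func_for(provider):
    # Mirrors PROVIDER_TO_MODEL.get(provider, DEFAULT_EMBEDDING_PROVIDER)()
    return PROVIDER_TO_MODEL.get(provider, DEFAULT_EMBEDDING_PROVIDER)()
```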
similar_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix="Input: {example}",
Use the example template from the config here? See how the seed examples are prepared.
This was unnecessary and has been removed in the latest commit.
)
label_examples = [{"input": label} for label in labels_list]

example_selector = SemanticSimilarityExampleSelector.from_examples(
Instead of creating this each time, is it possible to construct this example selector once and then just call sample labels each time? That would ensure we embed the label list only once.
Good idea.
I now construct the LabelSelector once in agent.run() and agent.plan(), so the label embeddings are only computed once.
# if large number of labels, filter labels_list by similarity of labels to input
if num_labels >= 50:
    example_prompt = PromptTemplate(
        input_variables=["input"],
nit: call this label?
removed in latest commit
# This is the embedding class used to produce embeddings which are used to measure semantic similarity.
OpenAIEmbeddings(),
# This is the VectorStore class that is used to store the embeddings and do a similarity search over.
VectorStoreWrapper(cache=False),
Went through the code; ideally we don't need to set this cache to False and can use the cache setting from the config, but if not, this would still be fine.
@@ -55,6 +60,40 @@ def construct_prompt(self, input: Dict, examples: List) -> str:
# prepare task guideline
labels_list = self.config.labels_list()
num_labels = len(labels_list)

# if large number of labels, filter labels_list by similarity of labels to input
if num_labels >= 50:
Can we do this based on a config setting? Just want to make sure we have the ability to turn this on or off. We can fall back to num_labels > 50 if the corresponding config setting is not set.
Agree with the previous comments: we should enable this "label selection" from a config parameter, not a hardcoded num_labels threshold.
done in f97a82d
Label selection can now be turned on/off in the config (as well as the number of labels to select), like so:
"label_selection": true,
"label_selection_count": 10
# This is the VectorStore class that is used to store the embeddings and do a similarity search over.
VectorStoreWrapper(cache=False),
# This is the number of examples to produce.
k=10,
Let's make this configurable.
done in latest commit.
My initial test on Ledgar: roughly the same, but I need to continue testing. Will try a full run on the Banking and Ledgar datasets.
split_lines = sampled_labels.split("\n")
labels_list = []
for i in range(1, len(split_lines)):
    if split_lines[i]:
        labels_list.append(split_lines[i])
This might break for input examples that contain newline characters (\n). Maybe I should check that split_lines[i] is in labels_list before appending.
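A sketch of that guard: keep only lines that are actual known labels, so stray fragments from inputs containing newlines are dropped (parse_sampled_labels is a hypothetical helper, not code from the PR):

```python
def parse_sampled_labels(sampled_labels, known_labels):
    # The membership check guards against input text that itself
    # contains "\n" and would otherwise leak into the label list.
    known = set(known_labels)
    return [line for line in sampled_labels.split("\n") if line in known]
```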
This should not be needed once the implementation is revamped
correct, this has been removed.
I suggest abstracting all this logic away into a "LabelSelector" class, similar to what the Example Selectors do.
The LabelSelector would be initialized once in the run/plan agent method by reading the appropriate fields from the config (again, very similar to the example selector).
The agent can then call this object's "select labels" function when labeling each example to get a list of the K most likely labels, like https://github.com/refuel-ai/autolabel/blob/main/src/autolabel/labeler.py#L183
and pass it to the task object (like https://github.com/refuel-ai/autolabel/blob/main/src/autolabel/labeler.py#L189).
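A minimal sketch of such a LabelSelector, assuming an injectable embedding function mapping a string to a vector; this is illustrative, not the actual implementation:

```python
import math

class LabelSelector:
    """Precompute label embeddings once; select the k most similar
    labels for each input via cosine similarity."""

    def __init__(self, labels, embedding_func, k=10):
        self.labels = labels
        self.k = k
        self.embedding_func = embedding_func
        # Label embeddings are computed once, at construction time.
        self.label_embeddings = [embedding_func(label) for label in labels]

    @staticmethod
    def _cos_sim(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
        return dot / norm if norm else 0.0

    def select_labels(self, text):
        # Embed the (formatted) input and rank labels by similarity.
        emb = self.embedding_func(text)
        ranked = sorted(
            zip(self.labels, self.label_embeddings),
            key=lambda pair: self._cos_sim(emb, pair[1]),
            reverse=True,
        )
        return [label for label, _ in ranked[: self.k]]
```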
label_examples = [{"input": label} for label in labels_list]

example_selector = SemanticSimilarityExampleSelector.from_examples(
    # This is the list of labels available to select from.
    label_examples,
    # This is the embedding class used to produce embeddings which are used to measure semantic similarity.
    OpenAIEmbeddings(),
    # This is the VectorStore class that is used to store the embeddings and do a similarity search over.
    VectorStoreWrapper(cache=False),
    # This is the number of examples to produce.
    k=10,
)
similar_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix="Input: {example}",
    suffix="",
    input_variables=["example"],
)
Please revamp this implementation.
- No need to go via FewShotExampleTemplate and the semantic similarity example selector.
- The label selection should conceptually consist of 3 steps: (i) input row --> formatted example, (ii) compute the embedding of the formatted example, (iii) find nearest neighbors from among the label list.
- The embeddings for labels in the label list should be computed just once, not once per row.
I have revamped the implementation.
The FewShotExampleTemplate and semantic similarity example selector have been removed entirely.
Embeddings for labels in the label list are computed only once, in agent.plan() and agent.run().
Worth noting that after refactoring this PR, I am noticing a slight (~5%) impact on labeling accuracy on the Ledgar dataset.
[Results tables for label_selection = true (k = 10) and label_selection = false omitted.]
Prior to the revamp, accuracy was about equal in both cases. Perhaps our cos_sim() function isn't quite as good as the langchain similarity selector I was using previously? That, or the embedding function is configured differently.
I have also noticed that embedding generation time is longer now than it was prior.
How I tested:
lgtm!
src/autolabel/labeler.py
Outdated
@@ -346,7 +376,13 @@ def plan(
)
else:
    examples = []
final_prompt = self.task.construct_prompt(input_i, examples)
if self.config.label_selection():
nit: this would give an error if label selection were set to true for any task other than classification, because construct_prompt has only been changed for the classification task. Any way to catch this, i.e. "label selection not supported for this task"?
Good catch. Will check that it is a classification task (if label_selection = true).
done in 2a6ec29
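The check might look something like this sketch (the function name and signature are assumptions, not the actual code from the commit):

```python
SUPPORTED_TASK_TYPES = {"classification"}

def validate_label_selection(task_type, label_selection):
    # Fail early: label selection is only wired into the classification task.
    if label_selection and task_type not in SUPPORTED_TASK_TYPES:
        raise ValueError(
            f"label_selection is not supported for task type '{task_type}'"
        )
```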
@@ -21,6 +21,11 @@

import json

from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
probably don't need these imports now?
removed these imports in e5ff12a
…es having OPENAI_API_KEY when importing autolabel
@iomap to follow up with any updates to documentation.
New option to filter labels_list by similarity to the input example.
Two new optional fields are now present in the prompt_config schema: "label_selection" and "label_selection_count".